import sys,os
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__),'..'))
sys.path.append(project_root)
sys.path.append(r'../witin_nn/nn')

import torch
import torch.optim as optim
from matplotlib import pyplot as plt
from model import ResNet18
from train_fun import load_dataset, train_runner, test_runner
from utils_get_fixed_point_model import get_fixed_model

from train_fun import compute_full_dataset_quantization_params

# Prefer the CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Label names in class-index order (the CIFAR-10 category set).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

model = ResNet18().to(device)
trainloader, testloader = load_dataset()

# Identical quantization settings are applied to every layer config below.
quant_settings = dict(
    use_quantization=True,
    scale_x=16,
    scale_y=16,
    scale_weight=16,
    bias_row_N=8,
    noise_level=0,
)

# Collect model.layer_config1 .. model.layer_config6 in order.
config = [getattr(model, f'layer_config{idx}') for idx in range(1, 7)]
for cfg in config:
    model.config_layer(cfg, **quant_settings)


# Epoch number of the quantization-aware checkpoint to convert and evaluate.
ckpt_epoch = '4'

# Resolve checkpoint paths against project_root (derived from __file__ at the
# top of this script) instead of the current working directory. This matches
# the old '../model_params/...' paths when launched from this script's folder
# and keeps working when launched from anywhere else.
file_path = os.path.join(
    project_root, 'model_params', 'witin', 'quan',
    f'ResNet18_param_{ckpt_epoch}.pth',
)
fixed_model_path = os.path.join(
    project_root, 'model_params', 'witin', 'fixed',
    f'ResNet18_param_{ckpt_epoch}_quan.pth',
)

# Convert the checkpoint at file_path to a fixed-point model. The 8 is
# presumably the target bit width, and the converted state dict is presumably
# written to fixed_model_path (it is loaded from there next). TODO confirm
# against utils_get_fixed_point_model.get_fixed_model.
get_fixed_model(model, 8, file_path, fixed_model_path)
# file_path = '../model_params/float/ResNet18_param_9.pth'

# map_location lets a checkpoint saved from a CUDA run be restored on a
# CPU-only host; load_state_dict then copies the tensors onto the model's device.
model.load_state_dict(torch.load(fixed_model_path, map_location=device))

# Alternative flow (disabled): start from a float checkpoint, load it into the
# model first, then convert it to fixed point.
# file_path = '../model_params/witin/float/ResNet18_param_' + '19' + '.pth'
# model.load_state_dict(torch.load(file_path))
# fixed_model_path = '../model_params/witin/fixed/ResNet18_param_' + '19' + '_quan.pth'
# get_fixed_model(model, 8, file_path, fixed_model_path)
# model.load_state_dict(torch.load(fixed_model_path))


# Debug aid: dump every parameter and its gradient.
# for name, param in model.named_parameters():
#     print(f"Name: {name}, Parameter: {param.data}, Gradient: {param.grad}")

# Quantization scale / zero point computed over the full training loader.
scale, zero_point = compute_full_dataset_quantization_params(trainloader)

# Show the configuration of the last configured layer group.
print(model.layer_config6)

# Evaluate on the test loader with quantization enabled, reusing the
# scale / zero point computed above.
eval_kwargs = {'use_quantization': True, 'scale': scale, 'zero_point': zero_point}
test_loss, test_acc = test_runner(model, device, testloader, **eval_kwargs)
print("test: test_loss, test_acc", test_loss, test_acc)
